{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "0dd37832",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch.distributions import MultivariateNormal\n",
    "import math"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "d6404b9b",
   "metadata": {},
   "source": [
    "## Multivariate Bayesian Inferencing of Mean of Gaussian likelihood, known Precision\n",
    "\n",
    "Previously, we studied the univariate case. Now let us consider the scenario where the data instances are multi-dimensional i.e vectors. This leads us to the multivariate case.\n",
    "\n",
    "In particular the training dataset is \n",
    "$$\n",
    "X \\equiv \\left\\lbrace \n",
    "\\vec{x}^{ \\left( 1 \\right) }, \\vec{x}^{ \\left( 2 \\right) }, \\cdots, \\vec{x}^{ \\left( i \\right) }, \\cdots, \\vec{x}^{ \\left( n \\right) }\n",
    "\\right\\rbrace\n",
    "$$\n",
    "\n",
    "Here, we assume the variance is known (a constant) but the mean of the data is unknown, modeled as a Gaussian random variable. \n",
    "\n",
    "We will express the Gaussian in terms of the precision matrix ${\\Lambda}$, instead of the covariance matrix ${\\Sigma}$ where ${\\Lambda} = {\\Sigma}^{-1}$.\n",
    "\n",
    "Since we assume that the data is Normally distributed:\n",
    "$$\n",
    "p\\left( X \\middle\\vert \\vec{\\mu} \\right) \\propto e^{ -\\frac{1}{2} \\sum_{i=1}^{n} \\left( \\vec{x}^{ \\left( i \\right) } - \\vec{ \\mu } \\right)^{T} {\\Lambda} \\left( \\vec{x}^{ \\left( i \\right) } - \\vec{ \\mu } \\right) }$$\n",
    "\n",
    "The variance is known - hence it is treated as a constant as opposed to a random variable.\n",
    "\n",
    "The mean  $\\vec{\\mu}$  is unknown and is treated as a random variable. This too is assumed to be a Gaussian, with mean  $\\vec{\\mu_{0}}$  and precision matrix $\\Lambda_{0}$ (not to be confused with  $\\vec{\\mu}$  and  $\\Lambda$  - the mean and precision matrix of the data itself ). Hence, the prior is\n",
    "\n",
    "$$p\\left( \\vec{\\mu }\\right) \\propto e^{ -\\frac{1}{2} \\left( \\vec{\\mu }- \\vec{ \\mu_{0} }\\right)^{T} {\\Lambda}_{0} \\left( \\vec{\\mu }- \\vec{ \\mu_{0} }\\right) }\n",
    "$$\n",
    "\n",
    "\n",
    "Using Bayes theorem, the posterior probability is \n",
    "\n",
    "$$\\overbrace{\n",
    "p\\left(\\vec{\\mu} \\middle\\vert X \\right)\n",
    "}^{posterior}\n",
    "=\n",
    "\\overbrace{\n",
    "p\\left( X \\middle\\vert \\vec{\\mu} \\right)\n",
    "}^{likelihood}\n",
    "\\overbrace{\n",
    "p\\left(\\vec{\\mu} \\right)\n",
    "}^{prior}$$\n",
    "\n",
    "The right hand side is the product of two Gaussians, which is a Gaussian itself. Let us denote its mean and precision matrix as $\\vec{\\mu_{n}}$ and $\\Lambda_{n}$.\n",
    "\n",
    "where\n",
    "$$\n",
    "\\begin{align*}\n",
    "&{\\Lambda}_{n} = n {\\Lambda} + {\\Lambda}_{0} \\\\\n",
    "& \\vec{\\mu_{n}} = {\\Lambda}_{n}^{-1} \\left( n{\\Lambda} \\bar{\\vec{x}} + {\\Lambda}_{0} \\vec{\\mu}_{0} \\right)    \n",
    "\\end{align*}$$\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "41bca08a",
   "metadata": {},
   "outputs": [],
   "source": [
    "def inference_known_precision(X, prior_dist, precision_known):\n",
    "    mu_mle = X.mean(dim=0)\n",
    "    n = X.shape[0]\n",
    "    \n",
    "    # Parameters of the prior\n",
    "    mu_0 = prior_dist.mean\n",
    "    precision_0 = prior_dist.precision_matrix\n",
    "    \n",
    "    # Parameters of posterior\n",
    "    precision_n = n * precision_known + precision_0\n",
    "    mu_n = torch.matmul(n * torch.matmul(mu_mle.unsqueeze(0), precision_known) + torch.matmul(mu_0.unsqueeze(0), precision_0), torch.inverse(precision_n))\n",
    "    posterior_dist = MultivariateNormal(mu_n, precision_matrix=precision_n)\n",
    "    return posterior_dist"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "id": "8d0b3818",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Let us assume that the true distribution is a normal distribution. The true distribution corresponds \n",
    "# to a single class.\n",
    "precision_known = torch.tensor([[0.1, 0], [0, 0.1]], dtype=torch.float)\n",
    "true_dist = MultivariateNormal(torch.tensor([20, 10], dtype=torch.float), precision_matrix=precision_known)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "id": "2ebed2cc",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True mean: tensor([20., 10.])\n",
      "MAP mean: tensor([[18.6117,  8.9122]])\n",
      "MLE mean: tensor([18.1845,  8.8156])\n"
     ]
    }
   ],
   "source": [
    "# Case 1\n",
    "# Let us assume our prior is a Normal distribution with a good estimate of the mean\n",
    "\n",
    "prior_mu = torch.tensor([19, 9], dtype=torch.float)\n",
    "prior_precision = torch.tensor([[0.33, 0], [0, 0.33]], dtype=torch.float)\n",
    "prior_dist = MultivariateNormal(prior_mu, precision_matrix=prior_precision)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "                           \n",
    "#Number of samples is low. \n",
    "n = 3\n",
    "X = true_dist.sample((n,))\n",
    "posterior_dist_low_n = inference_known_precision(X, prior_dist, precision_known)\n",
    "\n",
    "mu_mle = X.mean(dim=0)\n",
    "mu_map = posterior_dist_low_n.mean\n",
    "\n",
    "\n",
    "# When n is low, the posterior is dominated by the prior. Thus, a good prior can help offset the lack of data.\n",
    "# We can see this in the following case. \n",
    "\n",
    "# With a small sample (n=3), the MLE estimate of mean is worse compared to the MAP estimate of mean\n",
    "\n",
    "print(f\"True mean: {true_dist.mean}\")\n",
    "print(f\"MAP mean: {mu_map}\")\n",
    "print(f\"MLE mean: {mu_mle}\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "id": "3da9b8e4",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "True mean: tensor([20., 10.])\n",
      "MAP mean: tensor([[19.9701, 10.1186]])\n",
      "MLE mean: tensor([19.9733, 10.1223])\n"
     ]
    }
   ],
   "source": [
    "# Case 2\n",
    "prior_mu = torch.tensor([19, 9], dtype=torch.float)\n",
    "prior_precision = torch.tensor([[0.33, 0], [0, 0.33]], dtype=torch.float)\n",
    "prior_dist = MultivariateNormal(prior_mu, precision_matrix=prior_precision)\n",
    "\n",
    "torch.manual_seed(0)\n",
    "                           \n",
    "#Number of samples is high. \n",
    "n = 1000\n",
    "X = true_dist.sample((n,))\n",
    "posterior_dist_high_n = inference_known_precision(X, prior_dist, precision_known)\n",
    "\n",
    "mu_mle = X.mean(dim=0)\n",
    "mu_map = posterior_dist_high_n.mean\n",
    "\n",
    "\n",
    "# When n is high, the MLE tends to converge to the true distribution. The MAP also tends to converge to the MLE, \n",
    "# and in turn converges to the true distribution\n",
    "\n",
    "print(f\"True mean: {true_dist.mean}\")\n",
    "print(f\"MAP mean: {mu_map}\")\n",
    "print(f\"MLE mean: {mu_mle}\")"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "794ab2e9",
   "metadata": {},
   "source": [
    "### How to use the estimated mean parameter?\n",
    "\n",
    "We typically find $\\vec \\mu_{∗}$, the value of $\\vec\\mu$ that maximizes this posterior probability. In this particular case, the maxima of a Gaussian probability density occurs at the mean, hence, $\\vec\\mu_{∗}$ = $\\vec\\mu_{n}$.\n",
    "\n",
    "Given an arbitrary new data instance $x$, its probability of belonging to the class from which the training data has been sampled is $\\mathcal{N}\\left( \\vec x; \\vec\\mu_{n},  \\Lambda_{n} \\right)$."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "id": "fcd760d0",
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "MAP distribution mu: tensor([[19.9701, 10.1186]]) precision:tensor([[[0.1000, 0.0000],\n",
      "         [0.0000, 0.1000]]])\n"
     ]
    }
   ],
   "source": [
    "map_dist = MultivariateNormal(posterior_dist_high_n.mean, precision_matrix=precision_known)\n",
    "print(f\"MAP distribution mu: {map_dist.mean} precision:{map_dist.precision_matrix}\")"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
